module Spec.Curve25519.Test

open Lib.IntTypes
open Lib.RawIntTypes
open Lib.Sequence
open Lib.ByteSequence
open Spec.Curve25519

(* ********************* *)
(* RFC 7748 Test Vectors *)
(* ********************* *)

(** RFC 7748, Section 5.2, first test vector: input scalar (32 bytes,
    in the byte order printed in the RFC). Paired with [input1] and
    [expected1] in [test]. *)
let scalar1 = List.Tot.map u8_from_UInt8 [
  0xa5uy; 0x46uy; 0xe3uy; 0x6buy; 0xf0uy; 0x52uy; 0x7cuy; 0x9duy;
  0x3buy; 0x16uy; 0x15uy; 0x4buy; 0x82uy; 0x46uy; 0x5euy; 0xdduy;
  0x62uy; 0x14uy; 0x4cuy; 0x0auy; 0xc1uy; 0xfcuy; 0x5auy; 0x18uy;
  0x50uy; 0x6auy; 0x22uy; 0x44uy; 0xbauy; 0x44uy; 0x9auy; 0xc4uy
]

(** RFC 7748, Section 5.2, second test vector: input scalar (32 bytes,
    in the byte order printed in the RFC). Paired with [input2] and
    [expected2] in [test]. *)
let scalar2 = List.Tot.map u8_from_UInt8 [
  0x4buy; 0x66uy; 0xe9uy; 0xd4uy; 0xd1uy; 0xb4uy; 0x67uy; 0x3cuy;
  0x5auy; 0xd2uy; 0x26uy; 0x91uy; 0x95uy; 0x7duy; 0x6auy; 0xf5uy;
  0xc1uy; 0x1buy; 0x64uy; 0x21uy; 0xe0uy; 0xeauy; 0x01uy; 0xd4uy;
  0x2cuy; 0xa4uy; 0x16uy; 0x9euy; 0x79uy; 0x18uy; 0xbauy; 0x0duy
]

(** RFC 7748, Section 5.2, first test vector: input u-coordinate
    (32 bytes, in the byte order printed in the RFC). *)
let input1 = List.Tot.map u8_from_UInt8 [
  0xe6uy; 0xdbuy; 0x68uy; 0x67uy; 0x58uy; 0x30uy; 0x30uy; 0xdbuy;
  0x35uy; 0x94uy; 0xc1uy; 0xa4uy; 0x24uy; 0xb1uy; 0x5fuy; 0x7cuy;
  0x72uy; 0x66uy; 0x24uy; 0xecuy; 0x26uy; 0xb3uy; 0x35uy; 0x3buy;
  0x10uy; 0xa9uy; 0x03uy; 0xa6uy; 0xd0uy; 0xabuy; 0x1cuy; 0x4cuy
]

(** RFC 7748, Section 5.2, second test vector: input u-coordinate
    (32 bytes, in the byte order printed in the RFC).
    NOTE(review): the final byte 0x93 has its top bit set, so this vector
    presumably exercises the masking of the most significant bit that
    RFC 7748 Section 5 requires when decoding u-coordinates. Confirm
    against [Spec.Curve25519.scalarmult]. *)
let input2 = List.Tot.map u8_from_UInt8 [
  0xe5uy; 0x21uy; 0x0fuy; 0x12uy; 0x78uy; 0x68uy; 0x11uy; 0xd3uy;
  0xf4uy; 0xb7uy; 0x95uy; 0x9duy; 0x05uy; 0x38uy; 0xaeuy; 0x2cuy;
  0x31uy; 0xdbuy; 0xe7uy; 0x10uy; 0x6fuy; 0xc0uy; 0x3cuy; 0x3euy;
  0xfcuy; 0x4cuy; 0xd5uy; 0x49uy; 0xc7uy; 0x15uy; 0xa4uy; 0x93uy
]

(** RFC 7748, Section 5.2, first test vector: expected output
    u-coordinate, i.e. the result of [scalarmult] on [scalar1] and [input1]. *)
let expected1 = List.Tot.map u8_from_UInt8 [
  0xc3uy; 0xdauy; 0x55uy; 0x37uy; 0x9duy; 0xe9uy; 0xc6uy; 0x90uy;
  0x8euy; 0x94uy; 0xeauy; 0x4duy; 0xf2uy; 0x8duy; 0x08uy; 0x4fuy;
  0x32uy; 0xecuy; 0xcfuy; 0x03uy; 0x49uy; 0x1cuy; 0x71uy; 0xf7uy;
  0x54uy; 0xb4uy; 0x07uy; 0x55uy; 0x77uy; 0xa2uy; 0x85uy; 0x52uy
]

(** RFC 7748, Section 5.2, second test vector: expected output
    u-coordinate, i.e. the result of [scalarmult] on [scalar2] and [input2]. *)
let expected2 = List.Tot.map u8_from_UInt8 [
  0x95uy; 0xcbuy; 0xdeuy; 0x94uy; 0x76uy; 0xe8uy; 0x90uy; 0x7duy;
  0x7auy; 0xaduy; 0xe4uy; 0x5cuy; 0xb4uy; 0xb8uy; 0x73uy; 0xf8uy;
  0x8buy; 0x59uy; 0x5auy; 0x68uy; 0x79uy; 0x9fuy; 0xa1uy; 0x52uy;
  0xe6uy; 0xf8uy; 0xf7uy; 0x64uy; 0x7auy; 0xacuy; 0x79uy; 0x57uy
]

(** Print a 32-byte value as decimal bytes, each preceded by a space.
    The separator keeps the dump unambiguous. Without it, the byte
    sequences [1; 23] and [12; 3] would both print as "123". *)
let print_bytes (s:lseq uint8 32) : FStar.All.ML unit =
  List.iter
    (fun a -> (IO.print_string " "; IO.print_string (UInt8.to_string (u8_to_UInt8 a))))
    (to_list s)

(** Run a single scalar-multiplication test vector.
    Computes [scalarmult scalar input] and prints the expected and computed
    outputs. It then prints "Success!" or "Failure :(" and returns [true]
    iff every byte of the computed value equals the corresponding byte of
    [expected]. *)
let test_vector (scalar input expected:lseq uint8 32) : FStar.All.ML bool =
  let computed : lseq uint8 32 = scalarmult scalar input in
  (* Compare via uint_to_nat: the bytes are secret integers, so they are
     compared through their natural-number value, as before. *)
  let result : bool =
    for_all2 (fun a b -> uint_to_nat #U8 a = uint_to_nat #U8 b) computed expected in
  IO.print_string   "Expected Shared Secret:";
  print_bytes expected;
  IO.print_string "\nComputed Shared Secret:";
  print_bytes computed;
  if result then   IO.print_string "\nSuccess!\n"
  else IO.print_string "\nFailure :(\n";
  result

(** Run both RFC 7748 Section 5.2 test vectors. Prints diagnostics for
    each vector and returns [true] iff both computed outputs match. *)
let test () =
  (* of_list yields a sequence indexed by the list length. These
     normalization facts let the results be typed as lseq uint8 32. *)
  assert_norm(List.Tot.length scalar1 = 32);
  assert_norm(List.Tot.length scalar2 = 32);
  assert_norm(List.Tot.length input1 = 32);
  assert_norm(List.Tot.length input2 = 32);
  assert_norm(List.Tot.length expected1 = 32);
  assert_norm(List.Tot.length expected2 = 32);
  let scalar1 : lseq uint8 32 = of_list scalar1 in
  let scalar2 : lseq uint8 32 = of_list scalar2 in
  let input1 : lseq uint8 32 = of_list input1 in
  let input2 : lseq uint8 32 = of_list input2 in
  let expected1 : lseq uint8 32 = of_list expected1 in
  let expected2 : lseq uint8 32 = of_list expected2 in
  (* Run the vectors in order so vector 1's output precedes vector 2's.
     Both always run; they are not short-circuited. *)
  let result1 = test_vector scalar1 input1 expected1 in
  let result2 = test_vector scalar2 input2 expected2 in
  result1 && result2
